{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import gym\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "from collections import deque\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "# Discount factor applied to the bootstrapped next-state value in update_net.\n",
    "gamma = 0.98\n",
    "# Exploration rate: get_action uses it both as the probability of perturbing\n",
    "# the Q-values and as the scale of that perturbation.\n",
    "epsilon = 0.2\n",
    "# Factor multiplied into epsilon at the start of every training episode.\n",
    "epsilon_decay = 0.98\n",
    "# Upper bound on the number of training episodes.\n",
    "num_episodes = 1000\n",
    "# Transitions sampled per update; updates start once the replay buffer holds\n",
    "# more than this many transitions.\n",
    "batch_size = 128"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "6ba0f6849938e5b6"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "class QNet(torch.nn.Module):\n",
    "    \"\"\"Small dueling-style Q-network: two hidden layers feeding two linear heads.\n",
    "\n",
    "    Args:\n",
    "        input_dims: Size of the observation vector.\n",
    "        output_dims: Number of discrete actions (width of both output heads).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dims, output_dims):\n",
    "        super().__init__()\n",
    "        self.fc1 = torch.nn.Linear(input_dims, 8)\n",
    "        self.fc2 = torch.nn.Linear(8, 8)\n",
    "        # Advantage head.\n",
    "        self.fc3_1 = torch.nn.Linear(8, output_dims)\n",
    "        # Value head. It is output_dims wide rather than a single scalar; the\n",
    "        # width is kept so existing parameter shapes stay unchanged.\n",
    "        self.fc3_2 = torch.nn.Linear(8, output_dims)\n",
    "        self.relu = torch.nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Compute Q-values for a single state or a batch of states.\n",
    "\n",
    "        Args:\n",
    "            x: Float tensor whose last dimension is ``input_dims``.\n",
    "\n",
    "        Returns:\n",
    "            Tensor with the same leading dimensions as ``x`` and a last\n",
    "            dimension of ``output_dims``.\n",
    "        \"\"\"\n",
    "        x = self.relu(self.fc1(x))\n",
    "        x = self.relu(self.fc2(x))\n",
    "        avalue = self.fc3_1(x)\n",
    "        svalue = self.fc3_2(x)\n",
    "        # Centre the advantages per state, i.e. over the action axis only.\n",
    "        # Averaging over the whole tensor would mix the samples of a batch and\n",
    "        # make every Q-value depend on the other states in that batch.\n",
    "        sub = avalue - avalue.mean(dim=-1, keepdim=True)\n",
    "        return svalue + sub\n"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "fa8dec3e278013f4"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Action selection policy\n",
    "def get_action(q_net, state, eps=None):\n",
    "    \"\"\"Pick an action from the Q-values, sometimes perturbed by Gaussian noise.\n",
    "\n",
    "    With probability ``eps`` zero-mean Gaussian noise scaled by ``eps`` is\n",
    "    added to the Q-values before taking the argmax; otherwise the plain\n",
    "    argmax is returned.\n",
    "\n",
    "    Args:\n",
    "        q_net: Callable mapping ``state`` to a tensor of Q-values.\n",
    "        state: Input passed to ``q_net`` unchanged.\n",
    "        eps: Exploration rate. Defaults to the module-level ``epsilon``.\n",
    "\n",
    "    Returns:\n",
    "        A 0-dim integer tensor holding the (flat) index of the chosen action.\n",
    "    \"\"\"\n",
    "    if eps is None:\n",
    "        eps = epsilon\n",
    "    with torch.no_grad():\n",
    "        action_q_value = q_net(state)\n",
    "    if np.random.rand() <= eps:\n",
    "        # One noise sample per Q-value instead of a hard-coded (1, 2) array,\n",
    "        # converted to torch so both branches return the same type.\n",
    "        noise = torch.as_tensor(\n",
    "            np.random.randn(*action_q_value.shape) * eps,\n",
    "            dtype=action_q_value.dtype,\n",
    "        )\n",
    "        action_q_value = action_q_value + noise\n",
    "    return torch.argmax(action_q_value)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "c77e7e0aebedf5e5"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Update rule\n",
    "def update_net(memory, q_net, target_net):\n",
    "    \"\"\"Run one double-DQN gradient step on a random mini-batch.\n",
    "\n",
    "    The next action is chosen by ``q_net`` and evaluated by ``target_net``.\n",
    "    Relies on the module-level ``batch_size``, ``gamma``, ``mse`` and\n",
    "    ``optimizer``.\n",
    "\n",
    "    Args:\n",
    "        memory: Sequence of transitions laid out as\n",
    "            ``[state, action, next_state, reward, done]``.\n",
    "        q_net: Online network being trained.\n",
    "        target_net: Network used to evaluate the bootstrapped value.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(q_net, loss)`` where ``loss`` is the scalar loss tensor.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``memory`` holds fewer than ``batch_size`` transitions\n",
    "            (raised by ``random.sample``).\n",
    "    \"\"\"\n",
    "    batch = random.sample(memory, batch_size)\n",
    "    states = torch.stack([torch.as_tensor(t[0], dtype=torch.float32) for t in batch])\n",
    "    actions = torch.as_tensor([int(t[1]) for t in batch], dtype=torch.long)\n",
    "    next_states = torch.stack([torch.as_tensor(t[2], dtype=torch.float32) for t in batch])\n",
    "    rewards = torch.as_tensor([float(t[3]) for t in batch], dtype=torch.float32)\n",
    "    dones = torch.as_tensor([bool(t[4]) for t in batch], dtype=torch.bool)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        next_actions = q_net(next_states).argmax(dim=1, keepdim=True)\n",
    "        next_values = target_net(next_states).gather(1, next_actions).squeeze(1)\n",
    "        # Finished transitions do not bootstrap from the next state.\n",
    "        td_targets = torch.where(dones, rewards, rewards + gamma * next_values)\n",
    "\n",
    "    predicted = q_net(states)\n",
    "    # The target equals the current prediction everywhere except at the action\n",
    "    # actually taken, so only that entry contributes to the loss.\n",
    "    targets = predicted.detach().clone()\n",
    "    targets[torch.arange(len(batch)), actions] = td_targets\n",
    "\n",
    "    optimizer.zero_grad()\n",
    "    loss = mse(predicted, targets)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    return q_net, loss"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "b765dc55c34f365c"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def update_target_net(q_net, target_net):\n",
    "    \"\"\"Overwrite the parameters of ``target_net`` with those of ``q_net``.\n",
    "\n",
    "    Parameters are paired in ``parameters()`` order and copied in place, so\n",
    "    ``target_net`` keeps its own tensors. Nothing is returned.\n",
    "    \"\"\"\n",
    "    pairs = zip(q_net.parameters(), target_net.parameters())\n",
    "    with torch.no_grad():\n",
    "        for source, destination in pairs:\n",
    "            destination.data.copy_(source.data)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "e1da1e6b483fe582"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Program setup\n",
    "q_net = QNet(4, 2)\n",
    "target_net = QNet(4, 2)\n",
    "# Start the target network as an exact copy of the online network; otherwise\n",
    "# the first TD targets come from unrelated random weights until the first\n",
    "# call to update_target_net.\n",
    "target_net.load_state_dict(q_net.state_dict())\n",
    "mse = torch.nn.MSELoss()\n",
    "# Replay buffer; once full, appending drops the oldest transition.\n",
    "memory = deque(maxlen=5000)\n",
    "optimizer = optim.Adam(q_net.parameters(), lr=0.001)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "29f8d50986aa7232"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Main training loop\n",
    "env = gym.make(\"CartPole-v1\")\n",
    "# Training stops once a single episode accumulates this much shaped reward\n",
    "# (+r/10 per surviving step, -2 on termination).\n",
    "reward_goal = 3000\n",
    "\n",
    "for i in range(num_episodes):\n",
    "    s, _ = env.reset()\n",
    "    total_reward = 0\n",
    "    total_loss = []\n",
    "    step = 0\n",
    "    epsilon = epsilon * epsilon_decay\n",
    "    while total_reward < reward_goal:\n",
    "        s = torch.tensor(s)\n",
    "        # int() accepts a 0-dim tensor as well as a NumPy integer, so the loop\n",
    "        # does not depend on the exact type get_action returns.\n",
    "        action = int(get_action(q_net, s))\n",
    "        # NOTE: the fourth return value (truncation flag) is ignored, so the\n",
    "        # episode keeps stepping until termination or reward_goal is reached.\n",
    "        ns, r, done, _, _ = env.step(action)\n",
    "        step += 1\n",
    "        if done:\n",
    "            r = -2\n",
    "        else:\n",
    "            r = r / 10\n",
    "        total_reward += r\n",
    "        memory.append([s, action, ns, r, done])\n",
    "        s = ns\n",
    "        if len(memory) > batch_size:\n",
    "            q_net, loss = update_net(memory, q_net, target_net)\n",
    "            if (step + 1) % 5 == 0:\n",
    "                update_target_net(q_net, target_net)\n",
    "            total_loss.append(loss.item())\n",
    "        if done:\n",
    "            # No update has run while the buffer is still filling; avoid\n",
    "            # np.mean([]) and its RuntimeWarning.\n",
    "            ave_loss = np.mean(total_loss) if total_loss else float('nan')\n",
    "            print(\n",
    "                'epoch{},step:{},total_rewards:{},ave_loss:{}'.format(i, step, total_reward, ave_loss))\n",
    "            break\n",
    "    # total_reward is a float sum of 0.1 increments and will rarely equal the\n",
    "    # goal exactly, so compare with >= instead of ==.\n",
    "    if total_reward >= reward_goal:\n",
    "        print(\"训练完毕\")\n",
    "        break\n",
    "\n",
    "env.close()"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "c6f1d360bd38b4f0"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Run one rendered episode with the trained network and report its length.\n",
    "env = gym.make(\"CartPole-v1\", render_mode='human')\n",
    "try:\n",
    "    s, _ = env.reset()\n",
    "    step = 0\n",
    "    while True:\n",
    "        s = torch.tensor(s)\n",
    "        # int() accepts a 0-dim tensor as well as a NumPy integer.\n",
    "        action = int(get_action(q_net, s))\n",
    "        s, r, terminated, truncated, _ = env.step(action)\n",
    "        step += 1\n",
    "        # Stop on truncation as well as termination; otherwise a policy that\n",
    "        # never drops the pole would keep this loop running forever.\n",
    "        if terminated or truncated:\n",
    "            break\n",
    "    print('step:{}'.format(step))\n",
    "finally:\n",
    "    # Close the render window even if the episode raised.\n",
    "    env.close()\n"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "3e1231815d80b9ea"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Close the environment currently bound to env (the evaluation env created in the previous cell).\n",
    "env.close()"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "44ce5aa3b7d4c773"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
